{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch Tutorial\n",
    "MILA, November 2017\n",
    "\n",
    "By Sandeep Subramanian"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classifying MNIST & CIFAR-10 with Convnets & ResNets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "from __future__ import print_function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.init as init\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torchvision.transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define image transformations &  Initialize datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "mnist_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])\n",
    "mnist_train = torchvision.datasets.MNIST(root='./data', train=True, transform=mnist_transforms, download=True)\n",
    "mnist_test = torchvision.datasets.MNIST(root='./data', train=False, transform=mnist_transforms, download=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create multi-threaded DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True, num_workers=2)\n",
    "testloader = torch.utils.data.DataLoader(mnist_test, batch_size=64, shuffle=True, num_workers=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Main classifier that subclasses nn.Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class Classifier(nn.Module):\n",
    "    \"\"\"Convnet Classifier\"\"\"\n",
    "    def __init__(self):\n",
    "        super(Classifier, self).__init__()\n",
    "        self.conv = nn.Sequential(\n",
    "            # Layer 1\n",
    "            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1),\n",
    "            nn.Dropout(p=0.5),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=(2, 2), stride=2),\n",
    "            \n",
    "            # Layer 2\n",
    "            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1),\n",
    "            nn.Dropout(p=0.5),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=(2, 2), stride=2),\n",
    "            \n",
    "            # Layer 3\n",
    "            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=1),\n",
    "            nn.Dropout(p=0.5),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=(2, 2), stride=2),\n",
    "            \n",
    "            # Layer 4\n",
    "            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), padding=1),\n",
    "            nn.Dropout(p=0.5),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=(2, 2), stride=2)\n",
    "        )\n",
    "        # Logistic Regression\n",
    "        self.clf = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.clf(self.conv(x).squeeze())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cuda_available = torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "clf = Classifier()\n",
    "if cuda_available:\n",
    "    clf = clf.cuda()\n",
    "optimizer = torch.optim.Adam(clf.parameters(), lr=1e-4)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 0 Loss : 1.080 \n",
      "Epoch : 0 Test Acc : 90.880\n",
      "--------------------------------------------------------------\n",
      "Epoch : 1 Loss : 0.344 \n",
      "Epoch : 1 Test Acc : 94.320\n",
      "--------------------------------------------------------------\n",
      "Epoch : 2 Loss : 0.255 \n",
      "Epoch : 2 Test Acc : 96.250\n",
      "--------------------------------------------------------------\n",
      "Epoch : 3 Loss : 0.214 \n",
      "Epoch : 3 Test Acc : 96.530\n",
      "--------------------------------------------------------------\n",
      "Epoch : 4 Loss : 0.189 \n",
      "Epoch : 4 Test Acc : 97.180\n",
      "--------------------------------------------------------------\n",
      "Epoch : 5 Loss : 0.171 \n",
      "Epoch : 5 Test Acc : 97.600\n",
      "--------------------------------------------------------------\n",
      "Epoch : 6 Loss : 0.159 \n",
      "Epoch : 6 Test Acc : 97.720\n",
      "--------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Process Process-29:\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/process.py\", line 258, in _bootstrap\n",
      "Process Process-30:\n",
      "    self.run()\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/process.py\", line 258, in _bootstrap\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/process.py\", line 114, in run\n",
      "    self.run()\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/process.py\", line 114, in run\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/site-packages/torch/utils/data/dataloader.py\", line 45, in _worker_loop\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/site-packages/torch/utils/data/dataloader.py\", line 41, in _worker_loop\n",
      "    samples = collate_fn([dataset[i] for i in batch_indices])\n",
      "    data_queue.put((idx, samples))\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/site-packages/torchvision/datasets/mnist.py\", line 52, in __getitem__\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/queues.py\", line 392, in put\n",
      "    return send(obj)\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/site-packages/torch/multiprocessing/queue.py\", line 17, in send\n",
      "    ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/pickle.py\", line 224, in dump\n",
      "    self.save(obj)\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/pickle.py\", line 286, in save\n",
      "    f(self, obj) # Call unbound method with explicit self\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/pickle.py\", line 554, in save_tuple\n",
      "    save(element)\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/pickle.py\", line 286, in save\n",
      "    f(self, obj) # Call unbound method with explicit self\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/pickle.py\", line 454, in save_int\n",
      "    self.write(\"%c%c%c\" % (BININT2, obj&0xff, obj>>8))\n",
      "KeyboardInterrupt\n",
      "    img = Image.fromarray(img.numpy(), mode='L')\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/site-packages/PIL/Image.py\", line 2284, in fromarray\n",
      "    arr = obj.__array_interface__\n",
      "KeyboardInterrupt\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-10-83c0a16910bf>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0mlosses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0;31m# Train\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m     \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mcuda_available\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m             \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sandeep/anaconda2/lib/python2.7/site-packages/torch/utils/data/dataloader.pyc\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    194\u001b[0m         \u001b[0;32mwhile\u001b[0m \u001b[0mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    195\u001b[0m             \u001b[0;32massert\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshutdown\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatches_outstanding\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 196\u001b[0;31m             \u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_queue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    197\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatches_outstanding\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    198\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrcvd_idx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sandeep/anaconda2/lib/python2.7/multiprocessing/queues.pyc\u001b[0m in \u001b[0;36mget\u001b[0;34m()\u001b[0m\n\u001b[1;32m    376\u001b[0m             \u001b[0mracquire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    377\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 378\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mrecv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    379\u001b[0m             \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    380\u001b[0m                 \u001b[0mrrelease\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sandeep/anaconda2/lib/python2.7/site-packages/torch/multiprocessing/queue.pyc\u001b[0m in \u001b[0;36mrecv\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     20\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mrecv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     21\u001b[0m         \u001b[0mbuf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecv_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__getattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sandeep/anaconda2/lib/python2.7/pickle.pyc\u001b[0m in \u001b[0;36mloads\u001b[0;34m(str)\u001b[0m\n\u001b[1;32m   1386\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mloads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1387\u001b[0m     \u001b[0mfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mStringIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1388\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mUnpickler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1389\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1390\u001b[0m \u001b[0;31m# Doctest\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sandeep/anaconda2/lib/python2.7/pickle.pyc\u001b[0m in \u001b[0;36mload\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    862\u001b[0m             \u001b[0;32mwhile\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    863\u001b[0m                 \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 864\u001b[0;31m                 \u001b[0mdispatch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    865\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0m_Stop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstopinst\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    866\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mstopinst\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sandeep/anaconda2/lib/python2.7/pickle.pyc\u001b[0m in \u001b[0;36mload_reduce\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1137\u001b[0m         \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1138\u001b[0m         \u001b[0mfunc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1139\u001b[0;31m         \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1140\u001b[0m         \u001b[0mstack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1141\u001b[0m     \u001b[0mdispatch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mREDUCE\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_reduce\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sandeep/anaconda2/lib/python2.7/site-packages/torch/multiprocessing/reductions.pyc\u001b[0m in \u001b[0;36mrebuild_storage_fd\u001b[0;34m(cls, df, size)\u001b[0m\n\u001b[1;32m     66\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrebuild_storage_fd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     67\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mversion_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m         \u001b[0mfd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmultiprocessing\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrebuild_handle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     69\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     70\u001b[0m         \u001b[0mfd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sandeep/anaconda2/lib/python2.7/multiprocessing/reduction.pyc\u001b[0m in \u001b[0;36mrebuild_handle\u001b[0;34m(pickled_data)\u001b[0m\n\u001b[1;32m    153\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    154\u001b[0m     \u001b[0msub_debug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'rebuilding handle %d'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 155\u001b[0;31m     \u001b[0mconn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mClient\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maddress\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mauthkey\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcurrent_process\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauthkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    156\u001b[0m     \u001b[0mconn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    157\u001b[0m     \u001b[0mnew_handle\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrecv_handle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sandeep/anaconda2/lib/python2.7/multiprocessing/connection.pyc\u001b[0m in \u001b[0;36mClient\u001b[0;34m(address, family, authkey)\u001b[0m\n\u001b[1;32m    173\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    174\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mauthkey\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 175\u001b[0;31m         \u001b[0manswer_challenge\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mauthkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    176\u001b[0m         \u001b[0mdeliver_challenge\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mauthkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    177\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/sandeep/anaconda2/lib/python2.7/multiprocessing/connection.pyc\u001b[0m in \u001b[0;36manswer_challenge\u001b[0;34m(connection, authkey)\u001b[0m\n\u001b[1;32m    430\u001b[0m     \u001b[0;32mimport\u001b[0m \u001b[0mhmac\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    431\u001b[0m     \u001b[0;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mauthkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbytes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 432\u001b[0;31m     \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconnection\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecv_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m)\u001b[0m         \u001b[0;31m# reject large message\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    433\u001b[0m     \u001b[0;32massert\u001b[0m \u001b[0mmessage\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCHALLENGE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mCHALLENGE\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'message = %r'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mmessage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    434\u001b[0m     \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmessage\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCHALLENGE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "for epoch in range(50):\n",
    "    losses = []\n",
    "    # Train\n",
    "    for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
    "        if cuda_available:\n",
    "            inputs, targets = inputs.cuda(), targets.cuda()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        inputs, targets = Variable(inputs), Variable(targets)\n",
    "        outputs = clf(inputs)\n",
    "        loss = criterion(outputs, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        losses.append(loss.data[0])\n",
    "\n",
    "    print('Epoch : %d Loss : %.3f ' % (epoch, np.mean(losses)))\n",
    "    \n",
    "    # Evaluate\n",
    "    clf.eval()\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    for batch_idx, (inputs, targets) in enumerate(testloader):\n",
    "        if cuda_available:\n",
    "            inputs, targets = inputs.cuda(), targets.cuda()\n",
    "\n",
    "        inputs, targets = Variable(inputs, volatile=True), Variable(targets, volatile=True)\n",
    "        outputs = clf(inputs)\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += targets.size(0)\n",
    "        correct += predicted.eq(targets.data).cpu().sum()\n",
    "\n",
    "    print('Epoch : %d Test Acc : %.3f' % (epoch, 100.*correct/total))\n",
    "    print('--------------------------------------------------------------')\n",
    "    clf.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "cifar_train_transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.RandomCrop(32, padding=4),\n",
    "    torchvision.transforms.RandomHorizontalFlip(),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "\n",
    "cifar_test_transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar_train_transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=cifar_test_transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a single Residual Block\n",
    "\n",
    "Adapted from https://github.com/kuangliu/pytorch-cifar and https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class ResidualBlock(nn.Module):\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, stride=1):\n",
    "        super(ResidualBlock, self).__init__()\n",
    "        \n",
    "        # Conv Layer 1\n",
    "        self.conv1 = nn.Conv2d(\n",
    "            in_channels=in_channels, out_channels=out_channels,\n",
    "            kernel_size=(3, 3), stride=stride, padding=1, bias=False\n",
    "        )\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        \n",
    "        # Conv Layer 2\n",
    "        self.conv2 = nn.Conv2d(\n",
    "            in_channels=out_channels, out_channels=out_channels,\n",
    "            kernel_size=(3, 3), stride=1, padding=1, bias=False\n",
    "        )\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "    \n",
    "        # Shortcut connection to downsample residual\n",
    "        self.shortcut = nn.Sequential()\n",
    "        if stride != 1 or in_channels != out_channels:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(\n",
    "                    in_channels=in_channels, out_channels=out_channels,\n",
    "                    kernel_size=(1, 1), stride=stride, bias=False\n",
    "                ),\n",
    "                nn.BatchNorm2d(out_channels)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        out += self.shortcut(x)\n",
    "        out = F.relu(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class CIFARResNet18(nn.Module):\n",
    "    def __init__(self, num_classes=10):\n",
    "        super(CIFARResNet18, self).__init__()\n",
    "        \n",
    "        # Initial input conv\n",
    "        self.conv1 = nn.Conv2d(\n",
    "            in_channels=3, out_channels=64, kernel_size=(3, 3),\n",
    "            stride=1, padding=1, bias=False\n",
    "        )\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        \n",
    "        # Create stages 1-4\n",
    "        self.stage1 = self._create_stage(64, 64, stride=1)\n",
    "        self.stage2 = self._create_stage(64, 128, stride=2)\n",
    "        self.stage3 = self._create_stage(128, 256, stride=2)\n",
    "        self.stage4 = self._create_stage(256, 512, stride=2)\n",
    "        self.linear = nn.Linear(512, num_classes)\n",
    "    \n",
    "    # A stage is just two residual blocks for ResNet18\n",
    "    def _create_stage(self, in_channels, out_channels, stride):\n",
    "        return nn.Sequential(\n",
    "            ResidualBlock(in_channels, out_channels, stride),\n",
    "            ResidualBlock(out_channels, out_channels, 1)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.stage1(out)\n",
    "        out = self.stage2(out)\n",
    "        out = self.stage3(out)\n",
    "        out = self.stage4(out)\n",
    "        out = F.avg_pool2d(out, 4)\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.linear(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "clf = CIFARResNet18()\n",
    "if cuda_available:\n",
    "    clf = clf.cuda()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(clf.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)\n",
    "scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 200], gamma=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 0 Loss : 1.631 Time : 35.428 seconds \n",
      "Epoch : 0 Test Acc : 46.180\n",
      "--------------------------------------------------------------\n",
      "Epoch : 1 Loss : 1.120 Time : 36.086 seconds \n",
      "Epoch : 1 Test Acc : 57.270\n",
      "--------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Process Process-40:\n",
      "KeyboardInterrupt\n",
      "Process Process-39:\n",
      "Traceback (most recent call last):\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/process.py\", line 258, in _bootstrap\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/process.py\", line 258, in _bootstrap\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-15-410d974e8f8e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     14\u001b[0m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m         \u001b[0mlosses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     17\u001b[0m     \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "    self.run()\n",
      "    self.run()\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/process.py\", line 114, in run\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/process.py\", line 114, in run\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/site-packages/torch/utils/data/dataloader.py\", line 35, in _worker_loop\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/site-packages/torch/utils/data/dataloader.py\", line 35, in _worker_loop\n",
      "    r = index_queue.get()\n",
      "    r = index_queue.get()\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/queues.py\", line 376, in get\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/multiprocessing/queues.py\", line 378, in get\n",
      "    racquire()\n",
      "    return recv()\n",
      "KeyboardInterrupt\n",
      "  File \"/home/sandeep/anaconda2/lib/python2.7/site-packages/torch/multiprocessing/queue.py\", line 21, in recv\n",
      "    buf = self.recv_bytes()\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(200):\n",
    "    losses = []\n",
    "    scheduler.step()\n",
    "    # Train\n",
    "    start = time.time()\n",
    "    for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
    "        if cuda_available:\n",
    "            inputs, targets = inputs.cuda(), targets.cuda()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        inputs, targets = Variable(inputs), Variable(targets)\n",
    "        outputs = clf(inputs)\n",
    "        loss = criterion(outputs, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        losses.append(loss.data[0])\n",
    "    end = time.time()\n",
    "\n",
    "    print('Epoch : %d Loss : %.3f Time : %.3f seconds ' % (epoch, np.mean(losses), end - start))\n",
    "    # Evaluate\n",
    "    clf.eval()\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    for batch_idx, (inputs, targets) in enumerate(testloader):\n",
    "        if cuda_available:\n",
    "            inputs, targets = inputs.cuda(), targets.cuda()\n",
    "\n",
    "        inputs, targets = Variable(inputs, volatile=True), Variable(targets, volatile=True)\n",
    "        outputs = clf(inputs)\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += targets.size(0)\n",
    "        correct += predicted.eq(targets.data).cpu().sum()\n",
    "\n",
    "    print('Epoch : %d Test Acc : %.3f' % (epoch, 100.*correct/total))\n",
    "    print('--------------------------------------------------------------')\n",
    "    clf.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
